{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Creating features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "DATA Saved\n"
     ]
    }
   ],
   "source": [
    "import cv2\n",
    "import numpy as np\n",
    "import os\n",
    "import radialProfile\n",
    "import glob\n",
    "from matplotlib import pyplot as plt\n",
    "import pickle\n",
    "\n",
    "path = ['thispersondoesntexists_10K/', '100KFake_10K/','Flickr-Faces-HQ_10K/', 'celebA-HQ/']\n",
    "labels = [1,1,0,0]\n",
    "epsilon = 1e-8\n",
    "data = {}\n",
    "#number of samples from each dataset\n",
    "stop = 250\n",
    "number_iter = 4 * stop\n",
    "psd1D_total = np.zeros([number_iter, 722])\n",
    "label_total = np.zeros([number_iter])\n",
    "iter_ = 0\n",
    "\n",
    "for z in range(4):\n",
    "    \n",
    "    psd1D_average_org = np.zeros(722)\n",
    "    \n",
    "    for filename in glob.glob(path[z]+\"/*.jpg\"):\n",
    "        \n",
    "        img = cv2.imread(filename,0)\n",
    "        # Calculate FFT\n",
    "        f = np.fft.fft2(img)\n",
    "        fshift = np.fft.fftshift(f)\n",
    "        fshift += epsilon\n",
    "        magnitude_spectrum = 20*np.log(np.abs(fshift))\n",
    "        # Calculate the azimuthally averaged 1D power spectrum\n",
    "        psd1D = radialProfile.azimuthalAverage(magnitude_spectrum)\n",
    "\n",
    "        psd1D_total[iter_,:] = psd1D\n",
    "        label_total[iter_] = labels[z]\n",
    "\n",
    "        iter_+=1\n",
    "        if iter_ >= stop:\n",
    "            break\n",
    "\n",
    "data[\"data\"] = psd1D_total\n",
    "data[\"label\"] = label_total\n",
    "\n",
    "output = open('dataset_freq_1000.pkl', 'wb')\n",
    "pickle.dump(data, output)\n",
    "output.close()\n",
    "print(\"DATA Saved\")    "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Classification"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Average SVM: 1.0\n",
      "Average LR: 1.0\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import pickle\n",
    "\n",
    "#number of runs\n",
    "num = 5\n",
    "SVM = 0\n",
    "LR = 0\n",
    "\n",
    "for z in range(num):\n",
    "    # read python dict back from the file\n",
    "    pkl_file = open('dataset_freq_1000.pkl', 'rb')\n",
    "    \n",
    "    data = pickle.load(pkl_file)\n",
    "    pkl_file.close()\n",
    "    X = data[\"data\"]\n",
    "    y = data[\"label\"]\n",
    "    X.shape\n",
    "    try:\n",
    "\n",
    "        from sklearn.model_selection import train_test_split\n",
    "        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.2)\n",
    "\n",
    "        from sklearn.svm import SVC\n",
    "        svclassifier = SVC(kernel='linear')\n",
    "        svclassifier.fit(X_train, y_train)\n",
    "        #print('Accuracy on test set: {:.3f}'.format(svclassifier.score(X_test, y_test)))\n",
    "               \n",
    "        \n",
    "        from sklearn.linear_model import LogisticRegression\n",
    "        logreg = LogisticRegression(solver='liblinear', max_iter=1000)\n",
    "        logreg.fit(X_train, y_train)\n",
    "        #print('Accuracy on test set: {:.3f}'.format(logreg.score(X_test, y_test)))\n",
    "        \n",
    "        SVM+=svclassifier.score(X_train, y_train)\n",
    "        LR+=logreg.score(X_test, y_test)\n",
    "        \n",
    "    except:\n",
    "        num-=1\n",
    "        print(num)\n",
    "    \n",
    "\n",
    "print(\"Average SVM: \"+str(SVM/num))\n",
    "print(\"Average LR: \"+str(LR/num))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
